{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67337d44",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import pyceres\n",
    "import pycolmap\n",
    "import numpy as np\n",
    "from hloc.utils import viz_3d"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc59c6d8",
   "metadata": {},
   "source": [
    "## Synthetic reconstruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a155e6d8",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def create_reconstruction(num_points=50, num_images=2, seed=3):\n",
    "    state = np.random.RandomState(seed)\n",
    "    rec = pycolmap.Reconstruction()\n",
    "    p3d = state.uniform(-1, 1, (num_points, 3)) + np.array([0, 0, 3])\n",
    "    for p in p3d:\n",
    "        rec.add_point3D(p, pycolmap.Track(), np.zeros(3))\n",
    "    w, h = 640, 480\n",
    "    cam = pycolmap.Camera(model='SIMPLE_PINHOLE', width=w, height=h, params=np.array([max(w,h)*1.2, w/2, h/2]), id=0)\n",
    "    rec.add_camera(cam)\n",
    "    for i in range(num_images):\n",
    "        im = pycolmap.Image(id=i, name=str(i), camera_id=cam.camera_id, tvec=state.uniform(-1, 1, 3))\n",
    "        im.registered = True\n",
    "        p2d = cam.world_to_image(im.project(list(rec.points3D.values())))\n",
    "        p2d_obs = np.array(p2d) + state.randn(len(p2d), 2)\n",
    "        im.points2D = pycolmap.ListPoint2D([pycolmap.Point2D(p, id_) for p, id_ in zip(p2d_obs, rec.points3D)])\n",
    "        rec.add_image(im)\n",
    "    return rec\n",
    "\n",
    "rec_gt = create_reconstruction()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cee46c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = viz_3d.init_figure()\n",
    "viz_3d.plot_reconstruction(fig, rec_gt, min_track_length=0, color='rgb(255,0,0)')\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98a6db79",
   "metadata": {},
   "source": [
    "## Optimize 3D points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e199af72",
   "metadata": {},
   "outputs": [],
   "source": [
    "def define_problem(rec):\n",
    "    prob = pyceres.Problem()\n",
    "    loss = pyceres.TrivialLoss()\n",
    "    for im in rec.images.values():\n",
    "        cam = rec.cameras[im.camera_id]\n",
    "        for p in im.points2D:\n",
    "            cost = pyceres.factors.BundleAdjustmentCost(cam.model_id, p.xy, im.qvec, im.tvec)\n",
    "            prob.add_residual_block(cost, loss, [rec.points3D[p.point3D_id].xyz, cam.params])\n",
    "    for cam in rec.cameras.values():\n",
    "        prob.set_parameter_block_constant(cam.params)\n",
    "    return prob\n",
    "\n",
    "def solve(prob):\n",
    "    print(prob.num_parameter_blocks(), prob.num_parameters(), prob.num_residual_blocks(), prob.num_residuals())\n",
    "    options = pyceres.SolverOptions()\n",
    "    options.linear_solver_type = pyceres.LinearSolverType.DENSE_QR\n",
    "    options.minimizer_progress_to_stdout = True\n",
    "    options.num_threads = -1\n",
    "    summary = pyceres.SolverSummary()\n",
    "    pyceres.solve(options, prob, summary)\n",
    "    print(summary.BriefReport())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac2305ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "rec = create_reconstruction()\n",
    "problem = define_problem(rec)\n",
    "solve(problem)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43df4cd1",
   "metadata": {},
   "source": [
    "Add some noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "356a7ad2",
   "metadata": {},
   "outputs": [],
   "source": [
    "rec = create_reconstruction()\n",
    "for p in rec.points3D.values():\n",
    "    p.xyz += np.random.RandomState(0).uniform(-0.5, 0.5, 3)\n",
    "print(rec.points3D[1].xyz)\n",
    "problem = define_problem(rec)\n",
    "solve(problem)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e24927e",
   "metadata": {},
   "source": [
    "## Optimize poses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d9aca7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def define_problem2(rec):\n",
    "    prob = pyceres.Problem()\n",
    "    loss = pyceres.TrivialLoss()\n",
    "    for im in rec.images.values():\n",
    "        cam = rec.cameras[im.camera_id]\n",
    "        for p in im.points2D:\n",
    "            cost = pyceres.factors.BundleAdjustmentCost(cam.model_id, p.xy)            \n",
    "            prob.add_residual_block(cost, loss, [im.qvec, im.tvec, rec.points3D[p.point3D_id].xyz, cam.params])\n",
    "        prob.set_manifold(im.qvec, pyceres.QuaternionManifold())\n",
    "    for cam in rec.cameras.values():\n",
    "        prob.set_parameter_block_constant(cam.params)\n",
    "    for p in rec.points3D.values():\n",
    "        prob.set_parameter_block_constant(p.xyz)\n",
    "    return prob\n",
    "\n",
    "rec = create_reconstruction()\n",
    "for im in rec.images.values():\n",
    "    im.tvec += np.random.randn(3)/2\n",
    "print([np.linalg.norm(rec.images[i].tvec - rec_gt.images[i].tvec) for i in rec.images])\n",
    "problem = define_problem2(rec)\n",
    "solve(problem)\n",
    "print([np.linalg.norm(rec.images[i].tvec - rec_gt.images[i].tvec) for i in rec.images])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31196ca9",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert np.allclose(rec.cameras[0].params, rec_gt.cameras[0].params)\n",
    "for i in rec.images:\n",
    "    print(rec.images[i].tvec, rec_gt.images[i].tvec)\n",
    "    print(rec.images[i].qvec, rec_gt.images[i].qvec)\n",
    "rec.points3D[1].xyz, rec_gt.points3D[1].xyz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aab4bcbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = viz_3d.init_figure()\n",
    "viz_3d.plot_reconstruction(fig, rec_gt, min_track_length=0, color='rgb(255,0,0)')\n",
    "viz_3d.plot_reconstruction(fig, rec, min_track_length=0, color='rgb(0,255,0)')\n",
    "fig.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
